#include "grain_growth.hpp"
#include "cxxopts.hpp"

#include <chrono>
#include <cstddef>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

// Global switch for extra progress messages printed by main(); off by default.
bool debug_on{false};


/**
 * Load a multi-grain image from a CSV file in the format print_img_in_csv() writes.
 *
 * Layout: a header line "Nx,Ny,n_grains" (optionally followed by ",n_step"),
 * then n_grains * Nx lines of Ny comma-separated values. Element (pg, x, y)
 * is stored at index pg*Nx*Ny + x*Ny + y of the returned array.
 *
 * @param filename  path of the CSV file to read.
 * @param Nx        [out] grid size along x.
 * @param Ny        [out] grid size along y.
 * @param n_grains  [out] number of grain channels.
 * @param n_step    [out] assigned only when the header carries a fourth field;
 *                  otherwise left unchanged.
 * @return a new[]-allocated array of Nx*Ny*n_grains values; the caller owns it
 *         and must release it with delete[].
 * @throws std::runtime_error if the file cannot be opened, the header is
 *         malformed, or fewer values than expected can be read.
 */
valueType *read_init_state(const char *filename, uint &Nx, uint &Ny, uint &n_grains, uint &n_step)
{
    // Own the FILE* so it is closed on every exit path, including throws.
    std::unique_ptr<FILE, decltype(&fclose)> inp(fopen(filename, "r"), &fclose);
    if (!inp)
        throw std::runtime_error(std::string("read_init_state: cannot open '") + filename + "'");

    if (fscanf(inp.get(), "%u,%u,%u", &Nx, &Ny, &n_grains) != 3)
        throw std::runtime_error(std::string("read_init_state: malformed header in '") + filename + "'");

    // Optional fourth header field. Previously n_step was passed to a
    // three-conversion format and therefore never assigned. With a three-field
    // header the literal ',' fails to match the line break, nothing is
    // consumed, and n_step keeps its incoming value.
    uint header_step = 0;
    if (fscanf(inp.get(), ",%u", &header_step) == 1)
        n_step = header_step;

    // size_t arithmetic so large grids do not overflow the index in uint.
    const std::size_t plane = static_cast<std::size_t>(Nx) * Ny;

    // Value-initialised and held by unique_ptr until the whole grid has been
    // read, so a parse failure neither leaks nor exposes indeterminate values.
    std::unique_ptr<valueType[]> mtx(new valueType[plane * n_grains]());

    // NOTE(review): "%lf" assumes valueType is double; confirm in grain_growth.hpp.
    for (std::size_t pg = 0; pg < n_grains; ++pg)
    {
        for (std::size_t x = 0; x < Nx; ++x)
        {
            valueType *row = mtx.get() + pg * plane + x * Ny;
            for (std::size_t y = 0; y < Ny; ++y)
            {
                // The first value of a row has no leading comma; the rest do.
                const int got = (y == 0) ? fscanf(inp.get(), "%lf", row + y)
                                         : fscanf(inp.get(), ",%lf", row + y);
                if (got != 1)
                    throw std::runtime_error(std::string("read_init_state: truncated or malformed data in '") +
                                             filename + "'");
            }
        }
    }

    return mtx.release();
}

/**
 * Write a multi-grain image to a CSV file readable by read_init_state().
 *
 * Output: a header line "Nx,Ny,n_grains", then for each grain pg and each row
 * i one line holding the Ny values img[pg*Nx*Ny + i*Ny + j], comma-separated
 * and printed with "%lf".
 *
 * @param img       Nx*Ny*n_grains values in grain-major order (not modified).
 * @param filename  destination path; the file is created or truncated.
 * @param Nx        grid size along x.
 * @param Ny        grid size along y.
 * @param n_grains  number of grain channels.
 * @throws std::runtime_error if the file cannot be opened or writing fails
 *         (previously a failed fopen led to fprintf on a NULL stream).
 */
void print_img_in_csv(valueType *img, const char *filename, uint Nx, uint Ny,
                      uint n_grains)
{
    FILE *oup = fopen(filename, "w");
    if (oup == nullptr)
        throw std::runtime_error(std::string("print_img_in_csv: cannot open '") + filename + "' for writing");

    fprintf(oup, "%u,%u,%u\n", Nx, Ny, n_grains);

    // size_t arithmetic so large grids do not overflow the index in uint.
    const std::size_t plane = static_cast<std::size_t>(Nx) * Ny;
    for (std::size_t pg = 0; pg < n_grains; ++pg)
    {
        for (std::size_t i = 0; i < Nx; ++i)
        {
            const valueType *row = img + pg * plane + i * Ny;
            // Separator before every value except the first. Unlike the old
            // "first value, then the rest" form, this never touches row[0]
            // when Ny == 0.
            for (std::size_t j = 0; j < Ny; ++j)
            {
                if (j != 0)
                    fputc(',', oup);
                fprintf(oup, "%lf", row[j]);
            }
            fputc('\n', oup);
        }
    }

    // Buffered write errors may only surface when fclose flushes the stream.
    const bool write_failed = ferror(oup) != 0;
    if (fclose(oup) != 0 || write_failed)
        throw std::runtime_error(std::string("print_img_in_csv: error while writing '") + filename + "'");
}


/**
 * Command-line settings of the forward-simulation driver, filled by
 * parse_args(). Every member defaults to the value parse_args() uses when the
 * corresponding option is absent, so a default-constructed Args is never
 * partially uninitialised.
 */
class Args
{
public:
    uint nsteps = 100;                        // -s/--nsteps: number of simulation steps
    std::string input = "grain.in";           // -i/--input: input file
    std::string output = "grain.out";         // -o/--output: output file
    std::string bucket_output = "bucket.out"; // --bucket_output: bucket-information output file
    uint lshL = 1;                            // --lshL: L parameter for LSH
    uint lshK = 1;                            // --lshK: K parameter for LSH
    double lshr = 1e-4;                       // --lshr: r parameter for LSH
};

/**
 * Parse the command-line options of the forward-simulation test driver.
 *
 * Recognised options: -s/--nsteps, -o/--output, -i/--input, --lshK, --lshL,
 * --lshr, --bucket_output and -h/--help; unrecognised options are tolerated.
 * Each supplied option is echoed to stdout as "  <name> = <value>", and absent
 * options fall back to the defaults listed in the help text.
 *
 * @param argc  argument count, as received by main().
 * @param argv  argument vector, as received by main().
 * @return a heap-allocated Args owned by the caller (release with delete).
 *
 * Terminates the process with exit(0) after printing --help, and with exit(1)
 * on a cxxopts parse error or when an integer option is negative.
 */
Args *parse_args(int argc, const char *argv[])
{
    try
    {
        cxxopts::Options options(argv[0], " - test forward simulation of grain growth.");
        options
            .positional_help("[optional args]")
            .show_positional_help();

        options
            .set_width(70)
            .set_tab_expansion()
            .allow_unrecognised_options()
            .add_options()
                ("s,nsteps", "Number of steps of simulation (default=100)", cxxopts::value<int>(), "N")
                ("o,output", "Output file (default=grain.out)", cxxopts::value<std::string>(), "FILE")
                ("i,input", "Input file (default=grain.in)", cxxopts::value<std::string>(), "FILE")
                ("lshK", "K for LSH (default=1)", cxxopts::value<int>(), "INT")
                ("lshL", "L for LSH (default=1)", cxxopts::value<int>(), "INT")
                ("lshr", "r for LSH (default=1e-4)", cxxopts::value<float>(), "FLOAT")
                ("bucket_output", "Output file of the bucket information (default=bucket.out)", cxxopts::value<std::string>(), "FILE")
                ("h,help", "Print help")
#ifdef CXXOPTS_USE_UNICODE
                ("unicode", u8"A help option with non-ascii: à. Here the size of the"
                            " string should be correct")
#endif
            ;
        //("Nx", "size of x-axis (default=64)", cxxopts::value<int>(), "INT")
        //("Ny", "size of y-axis (default=64)", cxxopts::value<int>(), "INT")

        auto result = options.parse(argc, argv);

        if (result.count("help"))
        {
            std::cout << options.help({"", "Group"}) << std::endl;
            exit(0);
        }

        std::cout << "[Parse Args]" << std::endl;

        // Integer option stored as uint. A negative value used to be cast
        // straight to uint and wrap to a huge count, so reject it instead.
        auto read_uint = [&result](const std::string &key, uint fallback) -> uint {
            if (!result.count(key))
                return fallback;
            const int value = result[key].as<int>();
            std::cout << "  " << key << " = " << value << std::endl;
            if (value < 0)
            {
                std::cout << "error parsing options: " << key << " must be non-negative" << std::endl;
                exit(1);
            }
            return static_cast<uint>(value);
        };

        // String option with a default.
        auto read_string = [&result](const std::string &key, const std::string &fallback) -> std::string {
            if (!result.count(key))
                return fallback;
            const std::string value = result[key].as<std::string>();
            std::cout << "  " << key << " = " << value << std::endl;
            return value;
        };

        // Allocated only after parsing, so the error paths above own nothing.
        // Fields are read in the same order as they are echoed.
        Args *args = new Args;
        args->nsteps = read_uint("nsteps", 100);
        args->output = read_string("output", "grain.out");
        args->input = read_string("input", "grain.in");
        args->bucket_output = read_string("bucket_output", "bucket.out");
        args->lshK = read_uint("lshK", 1);
        args->lshL = read_uint("lshL", 1);

        if (result.count("lshr"))
        {
            const float lshr = result["lshr"].as<float>();
            std::cout << "  lshr = " << lshr << std::endl;
            args->lshr = static_cast<double>(lshr);
        }
        else
        {
            args->lshr = 1e-4;
        }

        auto arguments = result.arguments();
        std::cout << "  Saw " << arguments.size() << " arguments" << std::endl;

        std::cout << "[End of Parse Args]" << std::endl;

        // The Nx/Ny options are disabled (see the commented-out declarations
        // above); main() takes the grid size from the dataset instead.
        return args;
    }
    catch (const cxxopts::OptionException &e)
    {
        std::cout << "error parsing options: " << e.what() << std::endl;
        exit(1);
    }
}

int main(int argc, const char *argv[])
{

    Args *args = parse_args(argc, argv);

    // def parameters
    uint Nx = 64; //1024;   these will be changed later.
    uint Ny = 64; //1024;
    uint n_grains = 2;
    uint n_step = 500;

    uint lshK = args->lshK;
    uint lshL = args->lshL;
    valueType lsh_r = args->lshr;
    uint nsteps = args->nsteps;

    valueType h = 0.5;

    valueType A = 1.0;
    valueType B = 1.0;
    valueType L = 5.0;
    valueType kappa = 0.1;

    valueType dtime = 0.05;
    valueType ttime = 0.0;

    // // lsh-smile learned parameter
    // valueType init_L = 11.6504; // try to learn to 5.0
    // valueType init_A = 1.98483; // try to learn to 1.0
    // valueType init_B = 2.01454; // try to learn to 1.0
    // valueType init_kappa = -0.0834962; // try to learn to 0.1
    
    // torch learned parameter
    valueType init_L = 9.0462; // try to learn to 5.0
    valueType init_A = 1.9431; // try to learn to 1.0
    valueType init_B = 8.8226; // try to learn to 1.0
    valueType init_kappa = 0.8431; // try to learn to 0.1

    double lr = 1e-5;
    uint start_skip = 1;
    uint skip_step = 30;
    // uint skip_step = 5;
    uint epoch = 500;

    char* data_path = "../grain_growth_all_data_1";
    GrainGrowthDataset dataset(data_path, start_skip, skip_step);

    Nx = dataset.Nx;
    Ny = dataset.Ny;
    n_grains = dataset.n_grains;
    n_step = dataset.n_step;

    if (debug_on) {
        std::cout << "finish data loading" << std::endl;
    }

    double min_loss = 1000.0;

    // save the original frame for both ground truth and lsh-smile
    ReturnItem start_item = dataset.get_item(start_skip);

    valueType* original_first_frame = start_item.data.eta1_eta2;

    print_img_in_csv(original_first_frame, "../output_torch/ref_start.out", Nx, Ny, n_grains);
    print_img_in_csv(original_first_frame, "../output_torch/sim_start.out", Nx, Ny, n_grains);

    for (int index = start_skip; index < dataset.get_len() - 10; index+=dataset.skip_step) {
        ReturnItem rt = dataset.get_item(index);

        valueType* eta1_eta2_start = rt.data.eta1_eta2;
        valueType lshr = 0.01;
        uint lshK = 3;
        uint lshL = 10;
        int img_size = dataset.Nx;
        valueType h = 0.5;
        valueType dtime = 0.05;
        valueType ttime = 0.0;
        uint eta1_eta2_len = img_size * img_size * n_grains;

        // if (debug_on) {
        //     std::cout << "sum of eta1_eta2_start: " << sum_mtx(eta1_eta2_start, eta1_eta2_len) << std::endl;
        // }

        GrainGrowthOneStep one_step(img_size, img_size, n_grains, lshK, lshL, h,\
                                    init_A, init_B, init_L, init_kappa, dtime, lshr);
        one_step.encode_from_img(eta1_eta2_start);

        auto start = std::chrono::high_resolution_clock::now();
        for (int j = 0; j < skip_step; ++j) {
            one_step.next();
            // std::cout << "sim step: " << j << std::endl;
        }
        auto stop = std::chrono::high_resolution_clock::now();
        auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
        std::cout << "time of ts model forward: " << duration.count() << "ms in " << skip_step << "steps" << std::endl;

        if (debug_on) {
            std::cout << "success forward in ts model" << std::endl;
        }

        // save lsh-smile simulated result
        valueType* eta1_eta2_sim = one_step.decode_to_img();
        int index_ref = index + skip_step;
        string sim_filename = std::string("../output_torch/sim_") + std::to_string(index_ref) + string(".out");
        print_img_in_csv(eta1_eta2_sim, sim_filename.c_str(), Nx, Ny, n_grains);

        // if (debug_on) {
        //     std::cout << "sum of eta1_eta2_sim: " << sum_mtx(eta1_eta2_sim, eta1_eta2_len) << std::endl;
        // }

        // save ground truth result
        valueType* eta1_eta2_ref = rt.ref.eta1_eta2_ref;
        string ref_filename = std::string("../output_torch/ref_") + std::to_string(index_ref) + string(".out");
        print_img_in_csv(eta1_eta2_ref, ref_filename.c_str(), Nx, Ny, n_grains);

        // if (debug_on) {
        //     std::cout << "sum of eta1_eta2_ref: " << sum_mtx(eta1_eta2_ref, eta1_eta2_len) << std::endl;
        // }

        
        delete eta1_eta2_sim;
    }
    
    printf("Have a nice day!\n");
    return 0;
}
